#include "stdafx.h"
#include "mysql_module.h"

// ==========================================================//
// exported functions, don't forget the SMysqlSpace namespace
namespace SMysqlSpace {
	// Returns the CMysqlModule singleton, creating and initializing it when no
	// instance exists yet. Yields nullptr if CreateInstance() or Init() fails
	// (a failed Init() also destroys the freshly created instance). On success
	// AddRef() is called on the instance before it is handed back.
	IMysqlModule* STDCALL GetMysqlModule() {
		const bool bMissing = (CMysqlModule::Instance() == nullptr);
		if (bMissing) {
			// First use: build the instance ...
			if (!CMysqlModule::CreateInstance()) {
				return nullptr;
			}

			// ... and bring it up; tear it down again if that fails.
			const bool bReady = CMysqlModule::Instance()->Init();
			if (!bReady) {
				CMysqlModule::DestroyInstance();
				return nullptr;
			}
		}

		auto* poModule = CMysqlModule::Instance();
		poModule->AddRef();
		return poModule;
	}

	// Records the logger and log level in the globals gPoLogger and
	// gLoggerLevel. A null poLogger trips the assert when asserts are enabled,
	// but is stored regardless.
	void STDCALL SetMysqlLogger(IGNLogger* poLogger, uint32 dwLevel) {
		assert(poLogger && "logger is null");

		gLoggerLevel = dwLevel;
		gPoLogger = poLogger;
	}
}
// =========================================================//

namespace SMysqlSpace {
	// Definition of the static instance pointer; it starts out null.
	CMysqlModule* CMysqlModule::m_poInstance = nullptr;
	// A new module starts with m_ref at zero.
	CMysqlModule::CMysqlModule()
		: m_ref(0)
	{
	}

	// Calls Release() on every connection of every group, then empties the
	// group map.
	CMysqlModule::~CMysqlModule() {
		for (auto& groupEntry : m_group2Connection) {
			for (auto* poConnect : groupEntry.second) {
				poConnect->Release();
			}
			groupEntry.second.clear();
		}
		m_group2Connection.clear();
	}

	// Resets m_ref, the pending command list and the group map, then
	// initializes the recordset pool with initSize/growSize.
	// Returns whatever m_recordsetPool.Init() reports.
	bool CMysqlModule::Init() {
		m_group2Connection.clear();
		m_listCommand.clear();
		m_ref = 0;

		const bool bPoolReady = m_recordsetPool.Init(initSize, growSize);
		return bPoolReady;
	}

	bool CMysqlModule::Connect(int group,
		const SConnectMysqlInfo& connectInfo, int connectNum
	) {
		if (connectNum <= 0) {
			connectNum = 1;
		}

		// find group info, return false.
		if (m_group2Connection.find(group) != m_group2Connection.end()) {
			return false;
		}

		for (int i = 0; i < connectNum; i++) {
			CMysqlConnection* poConnect = new CMysqlConnection();
			if (nullptr == poConnect) {
				Critical("[mysql] get connection object error");
				return false;
			}

			if (!poConnect->Connect(connectInfo)) {
				Critical("[mysql] connect %d group error", group);
				delete poConnect;
				return false;
			}

			// start a thread to deal command
			poConnect->start();

			// add to vector
			m_group2Connection[group].push_back(poConnect);
		}

		return true;
	}

	// Looks up connection `index` within `group`.
	// Returns nullptr when the group is not registered or index lies outside
	// [0, number of connections in the group).
	IMysqlConnection* CMysqlModule::FindConnection(int group, int index) {
		auto itGroup = m_group2Connection.find(group);
		if (m_group2Connection.end() == itGroup) {
			return nullptr;
		}

		CConnectVec& vecGroup = itGroup->second;
		const int nCount = (int)vecGroup.size();
		if (index < 0 || index >= nCount) {
			return nullptr;
		}

		return vecGroup[index];
	}

	bool CMysqlModule::AddCommand(int group, int index,
		ICommand* poCommand, bool bHighPriority
	) {
		if (nullptr == poCommand) {
			return false;
		}
		IMysqlConnection* poConnect = this->FindConnection(group, index);
		if (nullptr == poConnect) {
			Critical("[mysql] can not find %d:%d connection", group, index);
			return false;
		}

		CMysqlConnection* poRealConnect = (CMysqlConnection*)poConnect;
		if (nullptr == poRealConnect) {
			Critical("[mysql] find null object");
			return false;
		}
		assert(poRealConnect && "invalid connect");
		poRealConnect->AddCommand(poCommand, bHighPriority);

		return true;
	}

	int CMysqlModule::EscapeString(int group,
		const char* from, int fromLen, char* to, int toLen
	) {
		IMysqlConnection* poConnect = this->FindConnection(group, 0);
		if (nullptr == poConnect) {
			return -1;
		}
		return poConnect->EscapeString(from, fromLen, to, toLen);
	}

	bool CMysqlModule::Run(int count) {
		if (count < 0) count = 0;

		do {
			ICommand* poCommand = nullptr;
			m_locker.Lock();
			if (!m_listCommand.empty()) {
				poCommand = m_listCommand.front();
				m_listCommand.pop_front();
				m_locker.Unlock();
			}
			else {
				m_locker.Unlock();
				return false;
			}

			if (nullptr == poCommand) {
				return false;
			}

			poCommand->OnExecuted();
			poCommand->Release();
		} while (count-- != 0);

		return true;
	}

	// Appends a command to m_listCommand while holding m_locker; Run() pops
	// entries from the same list. Null commands are ignored.
	void CMysqlModule::AddExecutedCommand(ICommand* poCommand) {
		if (poCommand != nullptr) {
			CAutoLock<CLocker> oGuard(m_locker);
			m_listCommand.push_back(poCommand);
		}
	}

	int CMysqlModule::QueryWithoutResult(int group, int index, const char* sql) {
		IMysqlConnection* poConnect = this->FindConnection(group, index);
		if (nullptr == poConnect) {
			return MYSQL_EXECUTE_NO_INDEX;
		}
		return poConnect->QueryWithoutResult(sql);
	}

	int CMysqlModule::QueryWithResult(int group, int index,
		const char* sql, IMysqlRecordset** ppRes
	) {
		IMysqlConnection* poConnect = this->FindConnection(group, index);
		if (nullptr == poConnect) {
			return MYSQL_EXECUTE_PARA_ERROR;
		}
		return poConnect->QueryWithResult(sql, ppRes);
	}

	// Drops one reference via DecRef(); once m_ref is no longer positive the
	// instance is torn down through DestroyInstance(). Calling this with
	// m_ref already at or below zero trips the assert.
	void CMysqlModule::Release() {
		assert(m_ref > 0);

		DecRef();
		if (m_ref > 0) {
			return;
		}

		DestroyInstance();
	}

	// Shuts down a whole group: calls Release() on each of its connections and
	// removes the group from m_group2Connection.
	// Returns false when the group is not registered, true otherwise.
	bool STDCALL CMysqlModule::Close(int group) {
		auto itGroup = m_group2Connection.find(group);
		if (m_group2Connection.end() == itGroup) {
			return false;
		}

		// Release every connection of the group ...
		for (auto* poConnect : itGroup->second) {
			poConnect->Release();
		}

		// ... then forget the group itself.
		m_group2Connection.erase(itGroup);

		return true;
	}

	// Returns GetLastError() of connection `index` in `group`, or an empty
	// string when no such connection exists.
	const char* STDCALL CMysqlModule::GetLastError(int group, int index) {
		IMysqlConnection* poTarget = this->FindConnection(group, index);
		if (poTarget != nullptr) {
			return poTarget->GetLastError();
		}
		return "";
	}

	// Returns GetLastErrorNo() of connection `index` in `group`, or -1 when no
	// such connection exists.
	int STDCALL CMysqlModule::GetLastErrorNo(int group, int index) {
		IMysqlConnection* poTarget = this->FindConnection(group, index);
		if (poTarget != nullptr) {
			return poTarget->GetLastErrorNo();
		}
		return -1;
	}
}
